Make Decoding Functions Graph-compatible (with XLA Support!) #271
Conversation
Thank you! Left a few comments
```python
expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
self.assertEqual(
```
Would just using model.predict work? That would still hit all the compiled function paths, and would let you avoid all this dummy metric stuff, which is hard to read.
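A rough sketch of what that could look like; the TestModel here is an illustrative stand-in for this PR's test model:

```python
import tensorflow as tf

class TestModel(tf.keras.Model):
    # Illustrative stand-in: the real test model calls the decoding
    # utility inside `call`.
    def call(self, inputs):
        return tf.identity(inputs)

inputs = tf.constant([[0, 1], [1, 2]])
model = TestModel()
model.compile()
# `predict` runs the compiled graph path, so no dummy metric is needed.
outputs = model.predict(inputs)
```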
Could also be good to test the call on a batched dataset (where the batch size is not statically known), as well as on a single constant input, as you are doing here.
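A sketch of the batched-dataset case (the model is an illustrative stand-in):

```python
import tensorflow as tf

# Stand-in model; the real test would wrap the decoding utility.
model = tf.keras.Sequential([tf.keras.layers.Lambda(tf.identity)])

# `.batch()` without drop_remainder leaves the batch dimension dynamic:
# the dataset's element shape is (None, 2) here.
ds = tf.data.Dataset.from_tensor_slices(
    tf.constant([[0, 1], [1, 2], [2, 3]])
).batch(2)

outputs = model.predict(ds)
```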
Ah, man. Stupid me. Should have used model.predict :P
keras_nlp/utils/text_generation.py
Outdated
```python
    ),
    body=one_step,
    loop_vars=[prompt],
    shape_invariants=[tf.TensorShape(shape_invariants)],
```
Can we just pass tf.TensorShape([None, None]) as the shape invariant? Generally we should support a static batch size of None; tf.data does this by default after calling .batch(), for example. Might simplify the code a bit.
```python
inputs = tf.constant([[0, 1], [1, 2]])
model = TestModel()
model.compile(metrics=[dummy_metric])
```
If you add jit_compile=True, does the test still pass?
If yes, we should test this with both jit_compile=True and False, using https://docs.pytest.org/en/6.2.x/parametrize.html.
If no, we should either try to fix things with JIT compilation, or make sure we track that in a follow-up issue.
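The parametrization could look roughly like this (the test body is illustrative):

```python
import pytest
import tensorflow as tf

class TestModel(tf.keras.Model):
    # Illustrative stand-in for the PR's test model.
    def call(self, inputs):
        return tf.identity(inputs)

@pytest.mark.parametrize("jit_compile", [True, False])
def test_generate_compiled(jit_compile):
    model = TestModel()
    model.compile(jit_compile=jit_compile)  # jit_compile=True enables XLA
    model.predict(tf.constant([[0, 1], [1, 2]]))
```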
@abheesht17 Let's try adding a test case for jit_compile=True, and we can run it on GPU. We recently added GPU test support in this repo.
It's not working with jit_compile=True. Complete error logs: https://p.ip.fi/2TNt. Looks like it won't work with shape_invariants.
Also, re: beam search, a separate PR sounds good!
@mattdangerw, thanks for the review! Addressed all comments, save the jit_compile one.
/gcbrun
I think a pull request went by recently where we stopped doing seeded random generation because of discrepancies. Is this safe to land as is, @chenmoneygithub @jessechancy?
Seeded random generation should be removed. This is mainly because, even when fully seeded, the random output differs on GPU in accelerator testing.
This is great. Just leaving some quick initial comments.
keras_nlp/utils/text_generation.py
Outdated
```python
        tf.cast(max_length, dtype=tf.int64),
    ),
    body=one_step,
    loop_vars=[state],
```
Can we avoid the state dict and just do loop_vars=(length, prompt) here? Might be a little more readable.
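A sketch of that shape, with an illustrative body (this also shows unpacking the result directly, as suggested further down):

```python
import tensorflow as tf

def one_step(length, prompt):
    # Illustrative body: append one token column and bump the length.
    next_token = tf.zeros((tf.shape(prompt)[0], 1), dtype=prompt.dtype)
    return length + 1, tf.concat([prompt, next_token], axis=1)

prompt = tf.constant([[0, 1], [1, 2]])
length = tf.shape(prompt)[1]
max_length = 8

length, prompt = tf.while_loop(
    cond=lambda length, prompt: length < max_length,
    body=one_step,
    loop_vars=(length, prompt),
    shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None])),
)
```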
keras_nlp/utils/text_generation.py
Outdated
```python
# Pad the prompt with `pad_token_id` to `max_length`. We use `map_fn` here
# because the batch_size might not be static.
prompt = tf.map_fn(
```
I feel like we should be able to make this simpler; we are just padding a batched tensor with pad_token_id to the sequence length, right? We should not need a map_fn for this.
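For instance, something along these lines (a sketch that keeps the batch dimension dynamic):

```python
import tensorflow as tf

prompt = tf.constant([[0, 1], [1, 2]])
max_length, pad_token_id = 8, 0

# Fill-and-concat instead of map_fn; tf.shape keeps the batch size dynamic.
padding = tf.fill(
    (tf.shape(prompt)[0], max_length - tf.shape(prompt)[1]), pad_token_id
)
prompt = tf.concat([prompt, padding], axis=-1)
```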
```python
    loop_vars=[state],
)[0]

prompt = state["prompt"]
if end_token_id is not None:
    prompt = mask_tokens_after_end_token(
```
Do we even need this function anymore, if we are just starting with a correctly sized tensor filled with pad_token_id?
Hmm, I guess we do, to avoid random tokens after the end_token_id.
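For illustration, masking everything after the first end token could be done along these lines (a sketch, not necessarily this PR's exact implementation):

```python
import tensorflow as tf

def mask_tokens_after_end_token(prompt, end_token_id, pad_token_id):
    # True at positions strictly after the first end_token_id in each row.
    is_end = tf.cast(tf.equal(prompt, end_token_id), tf.int32)
    after_end = tf.cumsum(is_end, axis=-1, exclusive=True) > 0
    # Replace those positions with pad_token_id (dtypes assumed to match).
    return tf.where(after_end, tf.fill(tf.shape(prompt), pad_token_id), prompt)
```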
Yep
A few comments, but this looks pretty good to me! I only commented on one of the four utilities, but comments apply to all.
keras_nlp/utils/text_generation.py
Outdated
```python
length = prompt.shape.as_list()[1]

# Pad the prompt with `pad_token_id` to `max_length`.
prompt = tf.concat(
```
Nit: maybe split this into two lines for readability?

```python
padding = tf.fill((tf.shape(prompt)[0], max_length - length), pad_token_id)
prompt = tf.concat((prompt, padding), axis=-1)
```
keras_nlp/utils/text_generation.py
Outdated
```python
while i < max_length:
    # If the prompt has reached our desired length, exit while loop.
    pred = token_probability_fn(prompt)
    length = prompt.shape.as_list()[1]
```
Can we just do something like

```python
batch_size, length = tf.shape(x)
```

and use that below? Then length and batch size are both tensors from the start.
Hmmm, I'll split this into two lines:

```python
batch_size = tf.shape(prompt)[0]
length = tf.shape(prompt)[1]
```

`batch_size, length = tf.shape(x)` does not work in graph mode.
Stack trace: https://p.ip.fi/6YAg
Ah, destructuring is too fancy for AutoGraph, I forgot (tuple unpacking iterates over the tensor, which graph mode disallows). Let's do:

```python
shape = tf.shape(prompt)
batch_size = shape[0]
length = shape[1]
```
keras_nlp/utils/text_generation.py
Outdated
```python
    return (length, prompt)

# Run a while loop till text of length `max_length` has been generated.
prompt = tf.while_loop(
```
```python
length, prompt = tf.while_loop(...)
```

just to avoid that `[1]`, which is not super readable.
```python
class TestModel(tf.keras.Model):
    def call(self, inputs, training=False):
        if not training:
```
Is there a reason you have to do the training switch here? It looks like you are never actually testing the training=True branch; might be nice to clean up the test a bit.
@chenmoneygithub do you know why the accelerator testing is failing here? This would be a great one to actually test on accelerators.
I found it out: the git branch hasn't been synced with the master branch, so the build file is outdated. @abheesht17 Could you sync and push again? Thanks!
Sure!
LGTM! Thanks!
Nice work! Dropped a comment on the test.
Also, could you help create a TODO(chenmoneygithub) at the top of text_generation.py saying we should refactor these utilities to share common code? The padding + scatter_update handling is more complex than before, so it would be nice if we could reuse it.
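The requested TODO might read something like this (wording illustrative):

```python
# TODO(chenmoneygithub): Refactor the decoding utilities so they share
# the padding + scatter_update logic instead of duplicating it.
```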
```diff
@@ -342,7 +406,7 @@ def test_generate_with_ragged_prompt(self):
     def test_assert_probability_distribution_generation_is_correct(self):
         def token_probability_fn(inputs):
             batch_size = inputs.shape[0]
-            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            prob = tf.constant([[0.0, 0.0, 0.0, 1.0]])
```
Do we need to change the number here? The original value seems to be more general?
Ah, yes. This was done to take care of accelerator testing. Seeded generation does not work there, so we've made the probability 1.
Resolves #241
Partially resolves #277
Will have to think a bit more about Beam Search.